"""
Description
-----------
Scripts to determine the robust parameter and uv tapering required to achieve a
circular beam. Note that this version doesn't take into account the different
channel spacings which may be present in a MS made up of multiple observations.

Usage
-----
The main function is get_robust_and_taper(), which can be imported into any
workflow. First you must activate the exoALMA virtual environment with

> source /lustre/cv/projects/exoALMA/casa_modular_6.2.1/bin/activate

If you want to call this function from the command line you can use

> python circularize_beam.py -ms path/to/file.ms -beam 0.3 -npix 1028 -dpix 0.02

where -ms describes the path to the measurement set, -beam is the desired
(circular) beam size, -npix describes the number of pixels in the resulting
image and -dpix is the pixel scale in [arcsec]. This will return the optimal
robust value to use, along with the major axis, minor axis (both in units of
[arcsec]) and position angle in [degrees] of the Gaussian uv-taper to apply.

If you use the functions as part of a workflow, you can use:

> results = get_robust_and_taper(*args, **kwargs)
> robust, uvtaper = parse_robust_and_taper(results)

where the robust will be returned as a float and the uvtaper as a list of
strings which can be provided directly to ``tclean``.

To Do
-----
- Check optimization methods.
- Check the appropriate window size for PSF fitting.
- Account for proper channel weighting.

Authors
-------
Richard Teague (MIT)
Ryan Loomis (NRAO)
John Ilee (Leeds)
"""

import argparse
import casatools
import numpy as np
from numba import njit
from numpy.fft import fftshift, fft2
from scipy.optimize import minimize
import scipy.constants as sc
# NOTE(review): this changes NumPy's floating-point error handling globally
# (not only inside this module's functions) as a side effect of importing the
# module, so divide-by-zero / overflow / invalid / underflow warnings are
# silenced for NumPy code run afterwards as well. Presumably it is here to
# quiet warnings such as the division by a zero taper axis in
# ``get_tapered_beam`` (the default bounds in ``optimize_taper`` allow 0) --
# TODO confirm intent.
np.seterr(all='ignore')

# Conversion factor from a Gaussian standard deviation to its full width at
# half maximum: FWHM = dx_to_FWHM * sigma, with 2 * sqrt(2 ln 2) ~= 2.3548.
dx_to_FWHM = 2.0 * np.sqrt(2.0 * np.log(2.0))


def gaussian2D(params, npix):
    """
    Build an ``(npix, npix)`` array containing a 2D Gaussian of amplitude 1
    centred at ``(npix - 1) / 2`` along both axes.

    Args:
        params (tuple): ``(sigma_a, sigma_b, angle)``: the standard deviations
            in [pix] along the two principal axes and the rotation angle in
            [deg]. Internally the coordinates are rotated by
            ``90 - angle`` degrees.
        npix (int): Side length of the (square) output array.

    Returns:
        gaussian (array): 2D array with shape ``(npix, npix)``. The first array
            axis plays the role of ``x`` and the second of ``y``.
    """
    assert len(params) == 3, "params = dx, dy, theta"
    sigma_a, sigma_b, angle = params

    theta = np.radians(90.0 - angle)
    cos_t, sin_t = np.cos(theta), np.sin(theta)

    # Pixel offsets from the array centre along each axis.
    centre = 0.5 * (npix - 1.0)
    x, y = np.indices((npix, npix)) - centre

    x_rot = x * cos_t - y * sin_t
    y_rot = x * sin_t + y * cos_t
    exponent = (x_rot / sigma_a)**2 + (y_rot / sigma_b)**2
    return np.exp(-0.5 * exponent)


def beam_chi2(params, psf):
    """
    Sum of squared residuals between a unit-amplitude Gaussian model and a PSF.

    NaN pixels in ``psf`` (e.g. those masked out below a threshold) are
    ignored through ``np.nansum``.

    Args:
        params (tuple): Gaussian widths along the two principal axes in [pix]
            and rotation angle in [deg], as accepted by ``gaussian2D``.
        psf (array): Square 2D PSF to compare against.

    Returns:
        chi2 (float): The chi-squared value.
    """
    assert psf.ndim == 2, "PSF must be 2D"
    assert psf.shape[0] == psf.shape[1], "PSF must be square"
    model = gaussian2D(params, psf.shape[0])
    residual = model - psf
    return np.nansum(residual**2)


def taper_chi2(taper_params, desired_size, gridded_weights, npix, dpix):
    """
    Score how far a uv-tapered synthesized beam is from a circular beam of
    FWHM ``desired_size``.

    The tapered beam is fitted with ``get_beam_parameters`` and only its major
    and minor axes are compared with ``desired_size``. This differs from
    ``beam_chi2``, which compares against the PSF pixels themselves.

    Args:
        taper_params (tuple): Gaussian major axis in [arcsec], minor axis in
            [arcsec] and rotation angle in [degrees] of the uv-taper.
        desired_size (float): Desired beam size in [arcsec].
        gridded_weights (array): Gridded weights.
        npix (int): Number of pixels per side of the PSF.
        dpix (float): Pixel size in [arcsec/pix].

    Returns:
        chi2 (float): ``(bmaj - desired_size)**2 + (bmin - desired_size)**2``
            with both beam axes expressed in [arcsec].
    """
    beam = get_tapered_beam(gridded_weights=gridded_weights,
                            taper_params=taper_params,
                            npix=npix,
                            dpix=dpix)
    major_pix, minor_pix = get_beam_parameters(beam)[:2]

    # Convert the fitted axes from [pix] to [arcsec] before comparing.
    major_resid = major_pix * dpix - desired_size
    minor_resid = minor_pix * dpix - desired_size
    return major_resid**2 + minor_resid**2


def optimize_taper(desired_size, robust, weights, npix, dpix, uu, vv,
                   optimize_kwargs=None):
    """
    Use scipy.optimize.minimize to derive the Gaussian uv-taper which will
    result in a circular beam.

    If the first minimization reports failure and the method was not already
    Nelder-Mead, a second attempt is made with Nelder-Mead and no bounds. The
    best parameters found are returned even if both attempts fail.

    Args:
        desired_size (float): Desired beam size in [arcsec].
        robust (float): Robust weighting parameter.
        weights (array): Ungridded weights.
        npix (int): Number of pixels per side of the PSF.
        dpix (float): Pixel size in [arcsec/pix].
        uu (array): Spatial frequencies in [lambda].
        vv (array): Spatial frequencies in [lambda].
        optimize_kwargs (optional[dict]): Extra keyword arguments forwarded to
            ``scipy.optimize.minimize``. Missing ``method`` / ``bounds`` keys
            default to ``"L-BFGS-B"`` and ``[(0, 10), (0, 10), (0, 360)]``.
            The dictionary is copied and never modified.

    Returns:
        taper_major (float): Gaussian major axis in [arcsec] for the uv-taper.
        taper_minor (float): Gaussian minor axis in [arcsec] for the uv-taper.
        taper_pa (float): Gaussian position angle in [degrees] for the
            uv-taper, wrapped into [0, 180).
        bmaj (float): Major axis FWHM in [pix] of the resulting tapered beam.
    """
    print("Optimizing uv taper...")

    # Generate the robust weights for the minimization.

    gridded_robust_weights = get_gridded_robust_weights(robust=robust,
                                                        weights=weights,
                                                        npix=npix,
                                                        dpix=dpix,
                                                        uu=uu,
                                                        vv=vv)

    # Define the kwargs for the optimization. Work on a copy: the fallback
    # below overwrites 'method' and 'bounds', which must not leak back into
    # the caller's dictionary.

    kw = {} if optimize_kwargs is None else dict(optimize_kwargs)
    kw.setdefault('method', "L-BFGS-B")
    kw.setdefault('bounds', [(0.0, 10.0), (0.0, 10.0), (0.0, 360.0)])

    # Run the optimization.

    p0 = [desired_size / 2.0, desired_size / 3.0, 90.0]
    args = (desired_size, gridded_robust_weights, npix, dpix)
    res = minimize(taper_chi2, p0, args=args, **kw)
    if res.success:
        print("Optimal uv taper found.")
    else:
        print("WARNING: Unable to optimize uv-taper parameters.")
        if kw['method'] != 'Nelder-Mead':
            print("\t Trying with different minimization method...")
            kw['method'] = 'Nelder-Mead'
            kw['bounds'] = None
            res = minimize(taper_chi2, p0, args=args, **kw)
            if res.success:
                print("Optimal uv taper found.")
            else:
                print("WARNING: Unable to optimize uv-taper parameters.")

    tapered_beam = get_tapered_beam(gridded_weights=gridded_robust_weights,
                                    taper_params=res.x,
                                    npix=npix,
                                    dpix=dpix)
    bmaj, bmin = get_beam_parameters(tapered_beam)[:2]
    print('Resulting beam has major axis of {:.3f}".'.format(bmaj * dpix))
    print('Resulting beam has minor axis of {:.3f}".'.format(bmin * dpix))

    # Return the best-fit parameters, even if the minimization fails. Note that
    # the tapers need to be in [arcsec] and the position angle in [degrees].

    taper_major, taper_minor, taper_pa = res.x
    return taper_major, taper_minor, taper_pa % 180.0, bmaj


def get_beam_parameters(psf, window=20, threshold=0.35):
    """
    Fit a 2D Gaussian to the provided PSF to derive the beam major axis, minor
    axis and position angle. If there are multiple channels, the median PSF
    will be fit. PSFs that still have more than three axes after squeezing
    (e.g., full polarization with a Stokes axis) are rejected.

    Args:
        psf (array): A 2D or 3D array describing the PSF. If multiple channels
            are considered, this is assumed to be the final axis.
        window (optional[int]): Size of window in [pix] to fit, taken around
            pixel ``psf.shape[0] / 2``.
        threshold (optional[float]): Minimum value to consider part of the PSF.
            A threshold of 0.35 is the default by CASA.

    Returns:
        major (float): Beam major axis FWHM in [pix].
        minor (float): Beam minor axis FWHM in [pix].
        phi (float): Position angle in [deg] of the major axis, in the angle
            convention of ``gaussian2D``.

    Raises:
        ValueError: If the PSF is not square, or has more than three axes
            after squeezing.
    """

    # Ensure the data is square. Check if there are multiple PSFs (multiple
    # channels), or a single channel. Full polarization is not supported and
    # raises a ValueError.

    psf = np.squeeze(psf)
    if psf.ndim < 2 or psf.shape[0] != psf.shape[1]:
        raise ValueError("PSF must be square.")
    if psf.ndim == 2:
        psf = np.expand_dims(psf, axis=2)
    if psf.ndim != 3:
        raise ValueError("Cannot handle full polarization observations.")

    # Take the median PSF and fit a Gaussian to it. Over narrow bandwidths this
    # should be a relatively safe thing to do. Window out the PSF and mask
    # values below ``threshold``.

    centre = psf.shape[0] / 2
    low_idx = int(centre - window / 2)
    high_idx = int(centre + window / 2) + 1
    psf_window = np.median(psf, axis=2)[low_idx:high_idx, low_idx:high_idx]
    psf_thresh = np.where(psf_window > threshold, psf_window, np.nan)

    # Fit a 2D Gaussian to the windowed and thresholded PSF. The axes sizes
    # are in [pix]; callers multiply by the pixel scale to get [arcsec].

    p0 = [window / 4.0, window / 4.0, 0.0]
    bounds = ((0.0, window / 2), (0.0, window / 2), (-180.0, 180.0))
    res = minimize(beam_chi2, x0=p0, args=(psf_thresh,), bounds=bounds)

    # Always order params as (major, minor, phi). A Gaussian with widths
    # (a, b) at angle phi is identical to one with widths (b, a) at phi + 90,
    # so when the fitted second axis is the larger one the angle must be
    # rotated by 90 degrees to describe the major axis.

    fwhm_a, fwhm_b = res.x[:2] * dx_to_FWHM
    phi = res.x[2]
    if fwhm_a >= fwhm_b:
        return fwhm_a, fwhm_b, phi
    phi = phi + 90.0
    if phi > 180.0:
        phi -= 360.0
    return fwhm_b, fwhm_a, phi


@njit(fastmath=True)
def grid_weights(togrid, npix, dpix, uu, vv):
    """
    Grid the weights onto an ``(npix, npix)`` uv-plane.

    Each sample is added to the cell ``int(npix / 2 + 0.5 + u / du)`` (and
    likewise for v), and again at the cell of its conjugate ``(-u, -v)``. Here
    ``du = 1 / (npix * dpix)`` with ``dpix`` converted to radians. The first
    array axis is u and the second is v.

    Cells falling outside the grid are skipped. This happens for samples with
    |u| or |v| beyond roughly ``1 / (2 * dpix)`` (dpix in radians). Without
    this check they would be written out of bounds, or wrap around for
    negative indices, in this compiled loop.

    Args:
        togrid (array): Weights to grid.
        npix (int): Number of pixels in the desired image.
        dpix (float): Pixel size in [arcsec/pix].
        uu (array): Spatial frequencies in [lambda].
        vv (array): Spatial frequencies in [lambda].

    Returns:
        gridded (array): Gridded weights with shape ``(npix, npix)``.
    """
    idx = npix / 2.0 + 0.5
    dd = 1.0 / npix / dpix / sc.arcsec
    gridded = np.zeros((npix, npix))
    for i in range(uu.size):
        du = uu[i] / dd
        dv = vv[i] / dd

        # The sample itself.
        iu = int(idx + du)
        iv = int(idx + dv)
        if 0 <= iu and iu < npix and 0 <= iv and iv < npix:
            gridded[iu, iv] += togrid[i]

        # Its Hermitian conjugate at (-u, -v).
        iu = int(idx - du)
        iv = int(idx - dv)
        if 0 <= iu and iu < npix and 0 <= iv and iv < npix:
            gridded[iu, iv] += togrid[i]
    return gridded


@njit(fastmath=True)
def ungrid_weights(toungrid, npix, dpix, uu, vv):
    """
    Ungrid the weights: look up, for every (u, v) sample, the value of the
    grid cell it falls in.

    The cell index convention matches ``grid_weights``: ``int(npix / 2 + 0.5
    + u / du)`` with ``du = 1 / (npix * dpix)`` and dpix in radians. Only the
    ``(+u, +v)`` cell is read. Samples whose cell lies outside the grid get a
    value of 0 instead of an out-of-bounds read.

    Args:
        toungrid (array): Gridded values with shape ``(npix, npix)``.
        npix (int): Number of pixels in the desired image.
        dpix (float): Pixel size in [arcsec/pix].
        uu (array): Spatial frequencies in [lambda].
        vv (array): Spatial frequencies in [lambda].

    Returns:
        ungridded (array): Ungridded values, one per sample.
    """
    idx = npix / 2.0 + 0.5
    dd = 1.0 / npix / dpix / sc.arcsec
    ungridded = np.zeros(uu.size)
    for i in range(uu.size):
        iu = int(idx + uu[i] / dd)
        iv = int(idx + vv[i] / dd)
        if 0 <= iu and iu < npix and 0 <= iv and iv < npix:
            ungridded[i] = toungrid[iu, iv]
    return ungridded


def get_beam(gridded_weights):
    """
    Turn gridded uv weights into a peak-normalized synthesized beam.

    The grid is shifted so that its centre cell moves to index 0, Fourier
    transformed, and shifted back so the beam peak sits at the array centre.
    Only the real part is kept, and it is divided by its maximum.

    Args:
        gridded_weights (array): Gridded weights.

    Returns:
        beam (array): Synthesized beam with a maximum value of 1.
    """
    centred = fftshift(gridded_weights)
    transformed = fft2(centred)
    beam = fftshift(transformed).real
    return beam / beam.max()


def get_tapered_beam(gridded_weights, taper_params, npix, dpix):
    """
    Synthesize a beam after multiplying the gridded weights by a Gaussian
    uv-taper.

    The taper axes, given as FWHMs in [arcsec], are converted into Gaussian
    widths in uv-grid pixels via ``npix * dpix / axis / dx_to_FWHM``. That
    Gaussian is built with ``gaussian2D`` and applied to the weights before
    ``get_beam`` is called.

    Args:
        gridded_weights (array): Gridded weights, shape ``(npix, npix)``.
        taper_params (tuple): Major axis in [arcsec], minor axis in [arcsec] and
            position angle in [degrees] of the Gaussian taper.
        npix (int): Number of pixels in the image.
        dpix (float): Pixels scale in [arcsec/pix].

    Returns:
        tapered_beam (array): Synthesized beam after uv-tapering.
    """
    extent = npix * dpix
    uv_sigma_major = extent / taper_params[0] / dx_to_FWHM
    uv_sigma_minor = extent / taper_params[1] / dx_to_FWHM
    taper = gaussian2D((uv_sigma_major, uv_sigma_minor, taper_params[2]), npix)
    assert gridded_weights.shape == taper.shape
    tapered_weights = gridded_weights * taper
    return get_beam(gridded_weights=tapered_weights)


def get_gridded_robust_weights(robust, weights, npix, dpix, uu, vv):
    """
    Return the gridded Briggs (robust) weights.

    The procedure has four steps:

    1. Grid the input weights to get ``W_k``.
    2. Compute ``f^2 = (5 * 10^-robust)^2 * sum(w) / sum(W_k^2)``.
    3. Scale each visibility's weight by ``1 / (1 + W_k * f^2)`` for the cell
       it falls in.
    4. Grid the re-weighted visibilities again.

    Args:
        robust (float): Robust weight to use, ranging from -2 for uniform
            weighting to +2 for natural weighting. Integers are accepted.
        weights (array): Ungridded weights.
        npix (int): Number of pixels in the desired image.
        dpix (float): Pixel size in [arcsec/pix].
        uu (array): Spatial frequencies in [lambda].
        vv (array): Spatial frequencies in [lambda].

    Returns:
        gridded_robust_weights (array): Gridded robust weights.

    Raises:
        ValueError: If ``abs(robust) > 2``.
    """

    # Check robust value is OK.

    if abs(robust) > 2.0:
        raise ValueError("robust value must be -2 <= robust <= +2.")

    # Calculate the robust weighting.

    gridded_weights = grid_weights(togrid=weights,
                                   npix=npix,
                                   dpix=dpix,
                                   uu=uu,
                                   vv=vv)

    # Use a float base: ``np.power(10, -robust)`` raises a ValueError for an
    # integer ``robust`` > 0, since NumPy forbids negative integer powers of
    # integers. For float ``robust`` the result is unchanged.
    f_sq = (5.0 * np.power(10.0, -robust))**2
    f_sq *= np.sum(weights) / np.sum(gridded_weights**2)
    robust_factor = 1.0 / (1.0 + gridded_weights * f_sq)

    robust_weights = ungrid_weights(toungrid=robust_factor,
                                    npix=npix,
                                    dpix=dpix,
                                    uu=uu,
                                    vv=vv)

    # Apply the robust weighting to the original weights and return.

    gridded_robust_weights = grid_weights(togrid=robust_weights * weights,
                                          npix=npix,
                                          dpix=dpix,
                                          uu=uu,
                                          vv=vv)
    return gridded_robust_weights


def get_robust_beam(robust, weights, npix, dpix, uu, vv):
    """
    Synthesize the beam obtained with Briggs weighting at a given robust value.

    Args:
        robust (float): Robust weight to use, ranging from -2 for uniform
            weighting to +2 for natural weighting.
        weights (array): Ungridded weights.
        npix (int): Number of pixels in the desired image.
        dpix (float): Pixel size in [arcsec/pix].
        uu (array): Spatial frequencies in [lambda].
        vv (array): Spatial frequencies in [lambda].

    Returns:
        robust_beam (array): Peak-normalized synthesized beam.
    """
    robust_grid = get_gridded_robust_weights(
        robust=robust, weights=weights, npix=npix, dpix=dpix, uu=uu, vv=vv,
    )
    beam = get_beam(gridded_weights=robust_grid)
    return beam


def get_tapered_beam_parameters(gridded_weights, taper_params, npix, dpix):
    """
    Fit the synthesized beam produced by a Gaussian uv-taper and report its
    shape in sky units.

    Args:
        gridded_weights (array): Gridded weights.
        taper_params (tuple): Major axis in [arcsec], minor axis in [arcsec] and
            position angle in [degrees] of the Gaussian taper to apply to the
            gridded weights.
        npix (int): Number of pixels in the image.
        dpix (float): Pixels scale in [arcsec/pixel].

    Returns:
        bmaj (float): Beam major axis in [arcsec].
        bmin (float): Beam minor axis in [arcsec].
        bpa (float): Beam position angle in [degrees].
    """
    beam = get_tapered_beam(gridded_weights=gridded_weights,
                            taper_params=taper_params,
                            npix=npix,
                            dpix=dpix)
    major_pix, minor_pix, angle = get_beam_parameters(beam)
    # Axes come back in [pix]; the angle needs no conversion.
    return major_pix * dpix, minor_pix * dpix, angle


def get_robust_beam_parameters(robust, weights, npix, dpix, uu, vv):
    """
    Fit the synthesized beam produced with a given robust value and report its
    shape in sky units.

    Args:
        robust (float): Robust weight to use, ranging from -2 for uniform
            weighting to +2 for natural weighting.
        weights (array): Ungridded weights.
        npix (int): Number of pixels in the image.
        dpix (float): Pixel size in [arcsec/pix].
        uu (array): Spatial frequencies in [lambda].
        vv (array): Spatial frequencies in [lambda].

    Returns:
        bmaj (float): Beam major axis in [arcsec].
        bmin (float): Beam minor axis in [arcsec].
        bpa (float): Beam position angle in [degrees].
    """
    beam = get_robust_beam(robust=robust,
                           weights=weights,
                           npix=npix,
                           dpix=dpix,
                           uu=uu,
                           vv=vv)
    major_pix, minor_pix, angle = get_beam_parameters(beam)
    # Axes come back in [pix]; the angle needs no conversion.
    return major_pix * dpix, minor_pix * dpix, angle


def get_visibilities(base_ms):
    """
    Read in the weights and (u, v) coordinates from the provided MS.

    UVW values are converted from [m] to [lambda] using the median of all
    channel frequencies across every spectral window. A warning is printed
    when their fractional standard deviation exceeds 1%. Autocorrelations
    (``ANTENNA1 == ANTENNA2``) are dropped, and so are rows flagged in their
    first channel.

    Args:
        base_ms (str): Path to the desired measurement set.

    Returns:
        weights (array): Ungridded weights, summed over the first axis of the
            ``WEIGHT`` column.
        uu (array): Spatial frequencies in [lambda].
        vv (array): Spatial frequencies in [lambda].
    """

    # Remove trailing slashes so the sub-table path below is well formed.

    base_ms = base_ms.rstrip('/')

    tb = casatools.table()

    # Grab the channel frequencies for all the SPWs in the MS. Each table is
    # closed in a ``finally`` so a failed read does not leave it open.

    tb.open(base_ms + '/SPECTRAL_WINDOW')
    try:
        ref_freq = tb.getcol("REF_FREQUENCY")
        chan_freq = np.concatenate([tb.getcell("CHAN_FREQ", rownr=i)
                                    for i in range(ref_freq.shape[0])])
    finally:
        tb.close()

    # Collapse the channel frequencies to the median channel frequency. This is
    # not great if there's a large variation in channel frequencies in the MS.

    if np.std(chan_freq) / np.median(chan_freq) > 0.01:
        print("WARNING: > 1% variation in channel frequencies.")
    chan_freq = np.median(chan_freq)

    # Use CASA table tools to get columns of UVW, FLAG, WEIGHT, etc.

    tb.open(base_ms)
    try:
        flag = tb.getcol("FLAG")
        uvw = tb.getcol("UVW")
        weights = np.sum(tb.getcol("WEIGHT"), axis=0)
        ant1 = tb.getcol("ANTENNA1")
        ant2 = tb.getcol("ANTENNA2")
    finally:
        tb.close()

    # Break out the u, v spatial frequencies and convert from m to lambda.
    # Resulting shape will be (nvis,) for uu and vv.

    uu, vv, _ = uvw * chan_freq / sc.c

    # Toss out the autocorrelation placeholders.

    xc = np.where(ant1 != ant2)[0]
    uu, vv, weights = uu[xc], vv[xc], weights[xc]

    # Apply the flags and return. ``keep`` is True for rows with no flag set
    # along the first axis, looking only at index 0 of the next axis.
    # NOTE(review): assumes FLAG is laid out as (ncorr, nchan, nrow), so that
    # only the first channel's flags are considered -- confirm this is intended.

    keep = np.logical_not(np.any(flag, axis=0))[0][xc]
    uu, vv, weights = uu[keep], vv[keep], weights[keep]
    return weights, uu, vv


def parse_robust_and_taper(results):
    """
    Parse the best-fit robust value and uv-taper into arguments for CASA.

    Args:
        results (tuple): ``(robust, bmaj, bmin, bpa)`` as returned by
            ``get_robust_and_taper()``, with the taper axes in [arcsec] and the
            position angle in [degrees].

    Returns:
        robust (float): Briggs robust value, passed through unchanged.
        uvtaper (list of str): uv-taper argument for ``tclean()``, e.g.
            ``['0.12arcsec', '0.46arcsec', '12.3degrees']``. Axes use two
            decimals and the angle one decimal.
    """
    robust, bmaj, bmin, bpa = results
    uvtaper = ['{:.2f}arcsec'.format(bmaj),
               '{:.2f}arcsec'.format(bmin),
               '{:.1f}degrees'.format(bpa)]
    return robust, uvtaper


def get_robust_and_taper(base_ms, beam, npix, dpix, robust_step=0.1,
                         tolerance=0.9):
    """
    Calculate the robust value and Gaussian uv-taper required to get a circular
    synthesized beam.

    Robust values are tried from +2 (natural) down to -2 (uniform) in steps of
    ``robust_step``. The first one whose synthesized major axis is smaller
    than ``beam * tolerance`` is kept. A Gaussian uv-taper is then optimized
    for that robust value so that both beam axes approach
    ``beam * tolerance``.

    Args:
        base_ms (str): Path to the desired measurement set.
        beam (float): Desired beam size in [arcsec].
        npix (int): Number of pixels in the image.
        dpix (float): Pixel size in [arcsec/pix].
        robust_step (optional[float]): Step to use to find the optimal robust
            value. A larger value is quicker, but a smaller value will give a
            better taper. Must be positive.
        tolerance (optional[float]): Shrink the desired ``beam`` value by this
            factor to ensure ``imsmooth`` can be used to do the final
            circularization.

    Returns:
        robust (float): Robust value for the imaging.
        taper_major (float): Gaussian major axis in [arcsec] for the uv-taper.
        taper_minor (float): Gaussian minor axis in [arcsec] for the uv-taper.
        taper_pa (float): Gaussian position angle in [degrees] for the uv-taper.

    Raises:
        ValueError: If ``robust_step`` is not positive, or if the desired beam
            is unattainable even with uniform weighting.
    """

    # Validate cheaply before the (slow) read of the visibilities. A
    # non-positive step gives an empty robust range, which would leave
    # ``robust`` undefined after the search loop below.

    if robust_step <= 0:
        raise ValueError("robust_step must be positive.")

    npix = int(npix)
    target_size = beam * tolerance

    # Read in the visibilities.

    print("Reading in the visibilites...")
    weights, uu, vv = get_visibilities(base_ms=base_ms)

    # Check whether the desired beam is achievable with uniform weighting;
    # if not, there is no point cycling through the robust values.

    uniform_bmaj = get_robust_beam_parameters(robust=-2.0,
                                              weights=weights,
                                              npix=npix,
                                              dpix=dpix,
                                              uu=uu,
                                              vv=vv)[0]
    if uniform_bmaj > target_size:
        raise ValueError("Desired beam unattainable with uniform weighting.")

    # Cycle through the different robust weights, from natural to uniform.

    robust_range = np.arange(-2.0, 2.001, robust_step)[::-1]
    robust_range = np.round(robust_range, 2)
    message = '\trobust = {:.2f} gives a major axis of {:.3f}".'

    print("Finding optimal robust value...")
    for robust in robust_range:
        bmaj = get_robust_beam_parameters(robust=robust,
                                          weights=weights,
                                          npix=npix,
                                          dpix=dpix,
                                          uu=uu,
                                          vv=vv)[0]
        print(message.format(robust, bmaj))
        if bmaj < target_size:
            break
    print("Optimal robust parameter found.")

    # Using this robust value, find the optimal uv-taper to circularize the
    # beam.

    taper_params = optimize_taper(desired_size=target_size,
                                  robust=robust,
                                  weights=weights,
                                  npix=npix,
                                  dpix=dpix,
                                  uu=uu,
                                  vv=vv)

    # Return the robust value and optimized taper parameters.
    # TODO: Check why the position angle is flipped here...

    return robust, taper_params[0], taper_params[1], -taper_params[2]


if __name__ == '__main__':

    # All four arguments are needed by ``get_robust_and_taper``. Marking them
    # required makes argparse report a missing one with a usage message,
    # instead of letting ``None`` fail deep inside the computation.
    parser = argparse.ArgumentParser(
        description='Find the robust value and Gaussian uv-taper needed to '
                    'achieve a circular synthesized beam.')
    parser.add_argument('-ms', type=str, required=True, help='path to MS')
    parser.add_argument('-beam', type=float, required=True,
                        help='desired beam size')
    parser.add_argument('-dpix', type=float, required=True,
                        help='pixel scale in arcsec')
    parser.add_argument('-npix', type=int, required=True,
                        help='number of pixels')
    args = parser.parse_args()

    robust, bmaj, bmin, bpa = get_robust_and_taper(base_ms=args.ms,
                                                   beam=args.beam,
                                                   npix=args.npix,
                                                   dpix=args.dpix)

    print('Optimal robust value = {:.3f}.'.format(robust))
    print('Optimal uv-taper major axis = {:.3f}".'.format(bmaj))
    print('Optimal uv-taper minor axis = {:.3f}".'.format(bmin))
    print('Optimal uv-taper position angle = {:.1f} degrees.'.format(bpa))
